using Flux
using ForwardDiff
using LinearAlgebra
using Distributions
using Roots
using Plots
using LaTeXStrings
using ProgressMeter
using GLMakie
using Random

# Seed the global RNG so every random draw in this script is reproducible.
Random.seed!(1234)

# Problem configuration:
#   d — feature dimension
#   K — number of labels (mixture components)
#   α — quantile level used to calibrate τ: τ is the α-quantile of the conformity
#       scores (see `h` and `τ_hat`). 0.05 is another value that has been used.
d, K, α = 2, 6, 0.1

"""
    generate_GMM(K, d; mean_range=100, cov_scale=15, cov_floor=25.0)

Draw random parameters `(β, μ, Σ)` of a `K`-component Gaussian mixture in `d`
dimensions, with `P(Y = k) = βₖ` and `X | Y = k ~ 𝓝(μₖ, Σₖ)`.

- `β`: uniform weights, `βₖ = 1/K`.
- `μ`: each mean has i.i.d. `Uniform(-mean_range, mean_range)` coordinates.
- `Σ`: `Aₖ Aₖᵀ + cov_floor * I` with `Aₖ = cov_scale * randn(d, d)`. The
  `cov_floor * I` term keeps every `Σₖ` positive definite when `cov_floor > 0`.

Random draws happen in this order: all `K` means first, then all `K` covariance factors.
"""
function generate_GMM(K, d; mean_range=100, cov_scale=15, cov_floor=25.0)
    β = fill(1 / K, K)
    μ = [rand(Uniform(-mean_range, mean_range), d) for _ in 1:K]
    A = [cov_scale * randn(d, d) for _ in 1:K]
    # A_k * A_k' multiplies a matrix by its own adjoint, giving a symmetric product.
    Σ = [A_k * A_k' + I(d) * cov_floor for A_k in A]
    return β, μ, Σ
end
β, μ, Σ = generate_GMM(K, d)

"""
    simulate_GMM(β, μ, Σ, n)

Draw `n` i.i.d. labelled samples from the mixture with weights `β`, means `μ` and
covariances `Σ`.

Returns `(X, Y)`:
- `X` is an `n × d` matrix with one sample per row, where `d = length(μ[1])`.
- `Y` is the length-`n` vector of labels drawn from `Categorical(β)`.
"""
function simulate_GMM(β, μ, Σ, n)
    d = length(μ[1])
    Y = rand(Distributions.Categorical(β), n)
    # Build each component distribution once instead of once per sample. Construction
    # does not consume random numbers, so the sampled values are the same.
    components = [MvNormal(μ_k, Σ_k) for (μ_k, Σ_k) in zip(μ, Σ)]
    X = zeros(n, d)
    for (i, k) in enumerate(Y)
        X[i, :] = rand(components[k])
    end
    return X, Y
end

"""
    plot_GMM(β, μ, Σ, n)

Sample `n` points from the mixture and display a Plots.jl scatter of the first two
feature coordinates. Points are grouped (coloured) by label and no legend is shown.
The figure is also saved as `GMM.pdf`.
"""
function plot_GMM(β, μ, Σ, n)
    samples, labels = simulate_GMM(β, μ, Σ, n)
    xs = samples[:, 1]
    ys = samples[:, 2]
    fig = Plots.scatter(xs, ys; group=labels, legend=false)
    display(fig)
    return savefig(fig, "GMM.pdf")
end

"""
    plot_GMM_3d(β, μ, Σ, n)

Sample `n` points from the mixture and display a GLMakie 3-D scatter coloured by label.

Samples with fewer than three feature dimensions are zero-padded to three columns.
Only the first three feature coordinates are plotted.
"""
function plot_GMM_3d(β, μ, Σ, n)
    X, Y = simulate_GMM(β, μ, Σ, n)
    # Derive sizes from the arguments rather than from the globals `d` / `K`.
    n_pad = max(0, 3 - size(X, 2))
    X = hcat(X, zeros(n, n_pad))
    n_labels = length(β)
    # Assumes the :auto palette has at least `n_labels` colours (noted as K ≤ 16).
    colors = theme_palette(:auto)[1:n_labels]
    scatterplot = GLMakie.scatter(X[:, 1], X[:, 2], X[:, 3], colormap=colors, color=Y)
    return display(scatterplot)
end
plot_GMM(β, μ, Σ, 2500)
plot_GMM_3d(β, μ, Σ, 2500)

# Classifier parameters θ = (W, b) from the Gaussian discriminant function
#     δₖ(x) = xᵀ Σₖ⁻¹ μₖ − ½ μₖᵀ Σₖ⁻¹ μₖ + log βₖ.
# Row k of W is Σₖ⁻¹ μₖ, and b[k] is the remaining constant term.
# NOTE: each class uses its own Σₖ here, not a pooled covariance as in textbook LDA.
W = zeros(K, d)
b = zeros(K)
for k in 1:K
    w_k = Σ[k] \ μ[k]                         # Σₖ⁻¹ μₖ via a linear solve, no explicit inverse
    W[k, :] = w_k
    b[k] = -0.5 * dot(μ[k], w_k) + log(β[k])  # the log prior enters with a PLUS sign
end
θ = (W, b)

"""
    f(x, θ)

ANN model: a 1-layer MLP with element-wise sigmoid activation.

With `θ = (W, b)`, returns the vector `σ.(W * x + b)`, which holds one score per label.
"""
function f(x, θ)
    weights, bias = θ
    logits = weights * x + bias
    return σ.(logits)
end

"""
    E(θ, x, y)

Conformity score (THR): the model score `f(x, θ)` assigned to label `y`.
"""
function E(θ, x, y)
    label_scores = f(x, θ)
    return label_scores[y]
end

# Some utility functions (inverse of sigmoid and standard normal CDF)
# Inverse of the logistic sigmoid (the logit): log(t / (1 - t)).
# Evaluates to -Inf at t = 0 and Inf at t = 1.
function σ_inv(t)
    odds = t / (1 - t)
    return log(odds)
end
# Standard normal CDF: Φ(t) = P(Z ≤ t) for Z ~ 𝓝(0, 1).
function Φ(t)
    standard_normal = Normal(0, 1)
    return cdf(standard_normal, t)
end

# Auxiliary functions to compute τ and its gradient (see notes)
"""
    h(τ, θ)

Residual whose root in `τ` is the population threshold:

    h(τ, θ) = Σₖ βₖ Φ((σ⁻¹(τ) − bₖ − Wₖᵀμₖ) / √(Wₖᵀ Σₖ Wₖ)) − α,

where `Wₖ` is row `k` of `W`. Given `Y = k`, the logit `WₖᵀX + bₖ` is normally distributed
with mean `Wₖᵀμₖ + bₖ` and variance `WₖᵀΣₖWₖ`. Since `σ` is increasing, the sum is
therefore the mixture probability that the true-label score `E(θ, X, Y)` is at most `τ`.

`τ` is first clamped to `[0, 1]`. The function reads the globals `α`, `β`, `μ` and `Σ`.
"""
function h(τ, θ)
    t = clamp(τ, 0, 1)
    W, b = θ
    logit_t = σ_inv(t)   # independent of k, so computed once

    total = -α           # named so it does not shadow the function `h`
    for k in eachindex(β)
        w = W[k, :]
        num = logit_t - b[k] - w' * μ[k]
        den = sqrt(w' * Σ[k] * w)
        total += β[k] * Φ(num / den)
    end

    return total
end
# ∂h/∂τ at (τ, θ), via reverse-mode AD (`Flux.gradient`).
h_τ(τ, θ) = only(Flux.gradient(t -> h(t, θ), τ))
# ∂h/∂θ at (τ, θ). For θ = (W, b) this is a pair of gradients (∂h/∂W, ∂h/∂b).
h_θ(τ, θ) = only(Flux.gradient(ϑ -> h(τ, ϑ), θ))

"""
    τ(θ)

Ground-truth threshold: the root of `h(·, θ)` in `[0, 1]`.

`h` is nondecreasing in `τ`, with `h(0, θ) = -α < 0` and `h(1, θ) = 1 - α > 0`, so
`(0, 1)` is a valid bracket. Bisection on it always converges. An unbracketed secant
step could leave `[0, 1]`, where `h` is flat because of the clamp.
"""
τ(θ) = find_zero(t -> h(t, θ), (0.0, 1.0))

"""
    ∇τ(θ)

Ground-truth gradient of `τ` with respect to `θ`, from the implicit function theorem:
`∇τ = -h_θ / h_τ`, both evaluated at `τ(θ)`. For `θ = (W, b)` this returns `(∇W, ∇b)`.
"""
function ∇τ(θ)
    τ_star = τ(θ)   # solve for the root once and reuse it
    return h_θ(τ_star, θ) ./ (-h_τ(τ_star, θ))
end

"""
    τ_hat(θ, X, Y)

ConfTr estimate of τ: the empirical `α`-quantile of the conformity scores
`E(θ, X[i, :], Y[i])` over the samples (rows of `X`).
"""
function τ_hat(θ, X, Y)
    scores = map(eachindex(Y)) do i
        E(θ, X[i, :], Y[i])
    end
    return quantile(scores, α)
end

"""
    ∇τ_hat(θ, X, Y)

ConfTr estimate of ∇τ: the forward-mode gradient of the empirical quantile
`τ_hat(θ, X, Y)` with respect to each component of `θ = (W, b)`.

When differentiating one component, the other is held at its value in `θ`.
Returns the tuple `(∇W, ∇b)`.
"""
function ∇τ_hat(θ, X, Y)
    W, b = θ
    # The fixed component comes from θ itself, NOT from the globals named W and b,
    # so the estimate is correct for any θ passed in.
    ∇W = ForwardDiff.gradient(Wvar -> τ_hat((Wvar, b), X, Y), W)
    ∇b = ForwardDiff.gradient(bvar -> τ_hat((W, bvar), X, Y), b)
    return (∇W, ∇b)
end

"""
    ∇τ_hat_VR(θ, X, Y, m)

VR-ConfTr estimate of ∇τ. It is the gradient, with respect to `θ`, of the mean conformity
score over the `m` samples whose scores lie closest to the empirical quantile
`τ_hat(θ, X, Y)`.

`m` may be fractional. It is capped at `n = length(Y)`, rounded up, and floored at 1,
so at least one sample is always used. Returns the gradient with respect to `θ` from
`Flux.gradient`, which is `(∇W, ∇b)` for `θ = (W, b)`.
"""
function ∇τ_hat_VR(θ, X, Y, m)
    n = length(Y)
    m = clamp(ceil(Int, min(m, n)), 1, n)
    q = τ_hat(θ, X, Y)
    # Order the samples by how far their score lies from the empirical quantile.
    distances = [abs(E(θ, X[i, :], Y[i]) - q) for i in 1:n]
    closest = sortperm(distances)[1:m]
    return Flux.gradient(ϑ -> mean(E(ϑ, X[i, :], Y[i]) for i in closest), θ)[1]
end

"""
    _stochastic_round(x)

Round `x` to `floor(x)` or `floor(x) + 1` at random. The upper value is chosen with
probability `x - floor(x)`, so the expected result is `x`. Exactly one `rand()` is drawn.
"""
function _stochastic_round(x)
    lower = floor(x)
    return rand() < x - lower ? lower + 1 : lower
end

# Mean over the first (run) axis, with that axis dropped.
_run_mean(samples) = dropdims(mean(samples; dims=1); dims=1)

# Mean squared Euclidean distance from `center` to each run's estimate
# (each slice along the first axis).
function _mean_sq_deviation(samples, center)
    return mean(norm(vec(s) - vec(center))^2 for s in eachslice(samples; dims=1))
end

"""
    main(; n_runs=200, n_max=1000)

Monte-Carlo comparison of the ConfTr (`∇τ_hat`) and VR-ConfTr (`∇τ_hat_VR`) estimators
of the ground-truth gradient `∇τ(θ)`, using the global mixture `(β, μ, Σ)` and
parameters `θ`.

For each batch size `n in 1:n_max`, it draws `n_runs` independent batches and records,
for both estimators and separately for the `W` and `b` components:
- the bias: the norm of the error of the mean estimate;
- the variance: the mean squared distance of the estimates to their mean.

Only the `W` curves are plotted. They use log y-axes and are saved to
`bias_variance.pdf`. The `b` statistics are computed but not plotted.
"""
function main(; n_runs=200, n_max=1000)
    # Ground-truth gradient of τ with respect to θ = (W, b).
    η_W, η_b = ∇τ(θ)

    bias_W = zeros(n_max)
    bias_b = zeros(n_max)
    bias_W_VR = zeros(n_max)
    bias_b_VR = zeros(n_max)

    variance_W = zeros(n_max)
    variance_b = zeros(n_max)
    variance_W_VR = zeros(n_max)
    variance_b_VR = zeros(n_max)

    @showprogress for n in 1:n_max
        η_W_hat = zeros(n_runs, K, d)
        η_b_hat = zeros(n_runs, K)
        η_W_hat_VR = zeros(n_runs, K, d)
        η_b_hat_VR = zeros(n_runs, K)
        for run in 1:n_runs
            X, Y = simulate_GMM(β, μ, Σ, n)

            η_hat = ∇τ_hat(θ, X, Y)
            η_W_hat[run, :, :] = η_hat[1]
            η_b_hat[run, :] = η_hat[2]

            # Number of near-quantile samples used by VR-ConfTr: m = α n / log(log n),
            # floored at 1, then randomly rounded to an integer with expectation m.
            # Other schedules that were tried include n/2, √n, n^(1/4),
            # α n / log(n + 1) and α n^(1 - 2α).
            # For n ≤ 2, log(log(max(2, n))) < 0, so the floor of 1 applies.
            m = _stochastic_round(max(1, α * n / log(log(max(2, n)))))

            η_hat_VR = ∇τ_hat_VR(θ, X, Y, m)
            η_W_hat_VR[run, :, :] = η_hat_VR[1]
            η_b_hat_VR[run, :] = η_hat_VR[2]
        end

        η_W_hat_mean = _run_mean(η_W_hat)
        η_b_hat_mean = _run_mean(η_b_hat)
        η_W_hat_mean_VR = _run_mean(η_W_hat_VR)
        η_b_hat_mean_VR = _run_mean(η_b_hat_VR)

        bias_W[n] = norm(η_W - η_W_hat_mean)
        bias_b[n] = norm(η_b - η_b_hat_mean)
        bias_W_VR[n] = norm(η_W - η_W_hat_mean_VR)
        bias_b_VR[n] = norm(η_b - η_b_hat_mean_VR)

        variance_W[n] = _mean_sq_deviation(η_W_hat, η_W_hat_mean)
        variance_b[n] = _mean_sq_deviation(η_b_hat, η_b_hat_mean)
        variance_W_VR[n] = _mean_sq_deviation(η_W_hat_VR, η_W_hat_mean_VR)
        variance_b_VR[n] = _mean_sq_deviation(η_b_hat_VR, η_b_hat_mean_VR)
    end

    plot_bias_W = Plots.plot(1:n_max, bias_W, yscale=:log10, label="ConfTr")
    Plots.plot!(plot_bias_W, 1:n_max, bias_W_VR, yscale=:log10, label="VR-ConfTr")
    Plots.xlabel!(plot_bias_W, "batch size")
    Plots.ylabel!(plot_bias_W, "bias")

    plot_variance_W = Plots.plot(1:n_max, variance_W, yscale=:log10, label="ConfTr")
    Plots.plot!(plot_variance_W, 1:n_max, variance_W_VR, yscale=:log10, label="VR-ConfTr")
    Plots.xlabel!(plot_variance_W, "batch size")
    Plots.ylabel!(plot_variance_W, "variance")

    # Stack the two panels vertically, with legends outside the plotting area.
    fig = Plots.plot(plot_bias_W, plot_variance_W, layout=(2, 1), legend=:outerright)
    return savefig(fig, "bias_variance.pdf")
end

main()